-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[MLIR] Add new complex.powi op #158722
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR] Add new complex.powi op #158722
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-flang-fir-hlfir Author: Akash Banerjee (TIFitis) ChangesAdd a new Patch is 24.12 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/158722.diff 14 Files Affected:
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 466458c05dba7..74a4e8f85c8ff 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -1331,14 +1331,20 @@ mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc,
return genLibCall(builder, loc, mathOp, mathLibFuncType, args);
auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0));
mlir::Value exp = args[1];
- if (!mlir::isa<mlir::ComplexType>(exp.getType())) {
- auto realTy = complexTy.getElementType();
- mlir::Value realExp = builder.createConvert(loc, realTy, exp);
- mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
- exp =
- builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
+ mlir::Value result;
+ if (mlir::isa<mlir::IntegerType>(exp.getType()) ||
+ mlir::isa<mlir::IndexType>(exp.getType())) {
+ result = builder.create<mlir::complex::PowiOp>(loc, args[0], exp);
+ } else {
+ if (!mlir::isa<mlir::ComplexType>(exp.getType())) {
+ auto realTy = complexTy.getElementType();
+ mlir::Value realExp = builder.createConvert(loc, realTy, exp);
+ mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
+ exp = builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp,
+ zero);
+ }
+ result = builder.create<mlir::complex::PowOp>(loc, args[0], exp);
}
- mlir::Value result = builder.create<mlir::complex::PowOp>(loc, args[0], exp);
result = builder.createConvert(loc, mathLibFuncType.getResult(0), result);
return result;
}
diff --git a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
index dced5f90d6924..42f5df160798c 100644
--- a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
+++ b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
@@ -61,63 +61,55 @@ void ConvertComplexPowPass::runOnOperation() {
fir::FirOpBuilder builder(mod, fir::getKindMapping(mod));
- mod.walk([&](complex::PowOp op) {
+ mod.walk([&](complex::PowiOp op) {
builder.setInsertionPoint(op);
Location loc = op.getLoc();
auto complexTy = cast<ComplexType>(op.getType());
auto elemTy = complexTy.getElementType();
-
Value base = op.getLhs();
- Value rhs = op.getRhs();
-
- Value intExp;
- if (auto create = rhs.getDefiningOp<complex::CreateOp>()) {
- if (isZero(create.getImaginary())) {
- if (auto conv = create.getReal().getDefiningOp<fir::ConvertOp>()) {
- if (auto intTy = dyn_cast<IntegerType>(conv.getValue().getType()))
- intExp = conv.getValue();
- }
- }
- }
-
+ Value intExp = op.getRhs();
func::FuncOp callee;
- SmallVector<Value> args;
- if (intExp) {
- unsigned realBits = cast<FloatType>(elemTy).getWidth();
- unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth();
- auto funcTy = builder.getFunctionType(
- {complexTy, builder.getIntegerType(intBits)}, {complexTy});
- if (realBits == 32 && intBits == 32)
- callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy);
- else if (realBits == 32 && intBits == 64)
- callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy);
- else if (realBits == 64 && intBits == 32)
- callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy);
- else if (realBits == 64 && intBits == 64)
- callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy);
- else if (realBits == 128 && intBits == 32)
- callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy);
- else if (realBits == 128 && intBits == 64)
- callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy);
- else
- return;
- args = {base, intExp};
- } else {
- unsigned realBits = cast<FloatType>(elemTy).getWidth();
- auto funcTy =
- builder.getFunctionType({complexTy, complexTy}, {complexTy});
- if (realBits == 32)
- callee = getOrDeclare(builder, loc, "cpowf", funcTy);
- else if (realBits == 64)
- callee = getOrDeclare(builder, loc, "cpow", funcTy);
- else if (realBits == 128)
- callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy);
- else
- return;
- args = {base, rhs};
- }
+ unsigned realBits = cast<FloatType>(elemTy).getWidth();
+ unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth();
+ auto funcTy = builder.getFunctionType(
+ {complexTy, builder.getIntegerType(intBits)}, {complexTy});
+ if (realBits == 32 && intBits == 32)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy);
+ else if (realBits == 32 && intBits == 64)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy);
+ else if (realBits == 64 && intBits == 32)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy);
+ else if (realBits == 64 && intBits == 64)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy);
+ else if (realBits == 128 && intBits == 32)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy);
+ else if (realBits == 128 && intBits == 64)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy);
+ else
+ return;
+ auto call = fir::CallOp::create(builder, loc, callee, {base, intExp});
+ op.replaceAllUsesWith(call.getResult(0));
+ op.erase();
+ });
- auto call = fir::CallOp::create(builder, loc, callee, args);
+ mod.walk([&](complex::PowOp op) {
+ builder.setInsertionPoint(op);
+ Location loc = op.getLoc();
+ auto complexTy = cast<ComplexType>(op.getType());
+ auto elemTy = complexTy.getElementType();
+ unsigned realBits = cast<FloatType>(elemTy).getWidth();
+ func::FuncOp callee;
+ auto funcTy = builder.getFunctionType({complexTy, complexTy}, {complexTy});
+ if (realBits == 32)
+ callee = getOrDeclare(builder, loc, "cpowf", funcTy);
+ else if (realBits == 64)
+ callee = getOrDeclare(builder, loc, "cpow", funcTy);
+ else if (realBits == 128)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy);
+ else
+ return;
+ auto call =
+ fir::CallOp::create(builder, loc, callee, {op.getLhs(), op.getRhs()});
op.replaceAllUsesWith(call.getResult(0));
op.erase();
});
diff --git a/flang/test/Lower/HLFIR/binary-ops.f90 b/flang/test/Lower/HLFIR/binary-ops.f90
index 1fbd333db37c3..7e1691dd1587a 100644
--- a/flang/test/Lower/HLFIR/binary-ops.f90
+++ b/flang/test/Lower/HLFIR/binary-ops.f90
@@ -193,7 +193,7 @@ subroutine complex_to_int_power(x, y, z)
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<complex<f32>>
! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<i32>
-! CHECK: %[[VAL_8:.*]] = complex.pow
+! CHECK: %[[VAL_8:.*]] = complex.powi %[[VAL_6]], %[[VAL_7]] : complex<f32>, i32
subroutine extremum(c, n, l)
integer(8), intent(in) :: l
diff --git a/flang/test/Lower/Intrinsics/pow_complex16i.f90 b/flang/test/Lower/Intrinsics/pow_complex16i.f90
index 1827863a57f43..0b26024b02021 100644
--- a/flang/test/Lower/Intrinsics/pow_complex16i.f90
+++ b/flang/test/Lower/Intrinsics/pow_complex16i.f90
@@ -4,7 +4,7 @@
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
! PRECISE: fir.call @_FortranAcqpowi({{.*}}){{.*}}: (complex<f128>, i32) -> complex<f128>
-! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
+! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f128>
complex(16) :: a
integer(4) :: b
b = a ** b
diff --git a/flang/test/Lower/Intrinsics/pow_complex16k.f90 b/flang/test/Lower/Intrinsics/pow_complex16k.f90
index 039dfd5152a06..90a9f5e03628d 100644
--- a/flang/test/Lower/Intrinsics/pow_complex16k.f90
+++ b/flang/test/Lower/Intrinsics/pow_complex16k.f90
@@ -4,7 +4,7 @@
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
! PRECISE: fir.call @_FortranAcqpowk({{.*}}){{.*}}: (complex<f128>, i64) -> complex<f128>
-! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
+! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f128>
complex(16) :: a
integer(8) :: b
b = a ** b
diff --git a/flang/test/Lower/amdgcn-complex.f90 b/flang/test/Lower/amdgcn-complex.f90
index 4ee5de4d2842e..a28eaea82379b 100644
--- a/flang/test/Lower/amdgcn-complex.f90
+++ b/flang/test/Lower/amdgcn-complex.f90
@@ -25,3 +25,12 @@ subroutine pow_test(a, b, c)
complex :: a, b, c
a = b**c
end subroutine pow_test
+
+! CHECK-LABEL: func @_QPpowi_test(
+! CHECK: complex.powi
+! CHECK-NOT: fir.call @_FortranAcpowi
+subroutine powi_test(a, b, c)
+ complex :: a, b
+ integer :: i
+ b = a ** i
+end subroutine powi_test
diff --git a/flang/test/Lower/power-operator.f90 b/flang/test/Lower/power-operator.f90
index 3058927144248..9f74d172a6bb2 100644
--- a/flang/test/Lower/power-operator.f90
+++ b/flang/test/Lower/power-operator.f90
@@ -96,7 +96,7 @@ subroutine pow_c4_i4(x, y, z)
complex :: x, z
integer :: y
z = x ** y
- ! CHECK: complex.pow
+ ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f32>, i32
! PRECISE: fir.call @_FortranAcpowi
end subroutine
@@ -105,7 +105,7 @@ subroutine pow_c4_i8(x, y, z)
complex :: x, z
integer(8) :: y
z = x ** y
- ! CHECK: complex.pow
+ ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f32>, i64
! PRECISE: fir.call @_FortranAcpowk
end subroutine
@@ -114,7 +114,7 @@ subroutine pow_c8_i4(x, y, z)
complex(8) :: x, z
integer :: y
z = x ** y
- ! CHECK: complex.pow
+ ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f64>, i32
! PRECISE: fir.call @_FortranAzpowi
end subroutine
@@ -123,7 +123,7 @@ subroutine pow_c8_i8(x, y, z)
complex(8) :: x, z
integer(8) :: y
z = x ** y
- ! CHECK: complex.pow
+ ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f64>, i64
! PRECISE: fir.call @_FortranAzpowk
end subroutine
@@ -142,4 +142,3 @@ subroutine pow_c8_c8(x, y, z)
! CHECK: complex.pow %{{.*}}, %{{.*}} : complex<f64>
! PRECISE: fir.call @cpow
end subroutine
-
diff --git a/flang/test/Transforms/convert-complex-pow.fir b/flang/test/Transforms/convert-complex-pow.fir
index d980817aba9b9..4555fea61e496 100644
--- a/flang/test/Transforms/convert-complex-pow.fir
+++ b/flang/test/Transforms/convert-complex-pow.fir
@@ -2,18 +2,12 @@
module {
func.func @pow_c4_i4(%arg0: complex<f32>, %arg1: i32) -> complex<f32> {
- %c0 = arith.constant 0.000000e+00 : f32
- %c1 = fir.convert %arg1 : (i32) -> f32
- %c2 = complex.create %c1, %c0 : complex<f32>
- %0 = complex.pow %arg0, %c2 : complex<f32>
+ %0 = complex.powi %arg0, %arg1 : complex<f32>, i32
return %0 : complex<f32>
}
func.func @pow_c4_i8(%arg0: complex<f32>, %arg1: i64) -> complex<f32> {
- %c0 = arith.constant 0.000000e+00 : f32
- %c1 = fir.convert %arg1 : (i64) -> f32
- %c2 = complex.create %c1, %c0 : complex<f32>
- %0 = complex.pow %arg0, %c2 : complex<f32>
+ %0 = complex.powi %arg0, %arg1 : complex<f32>, i64
return %0 : complex<f32>
}
@@ -23,18 +17,12 @@ module {
}
func.func @pow_c8_i4(%arg0: complex<f64>, %arg1: i32) -> complex<f64> {
- %c0 = arith.constant 0.000000e+00 : f64
- %c1 = fir.convert %arg1 : (i32) -> f64
- %c2 = complex.create %c1, %c0 : complex<f64>
- %0 = complex.pow %arg0, %c2 : complex<f64>
+ %0 = complex.powi %arg0, %arg1 : complex<f64>, i32
return %0 : complex<f64>
}
func.func @pow_c8_i8(%arg0: complex<f64>, %arg1: i64) -> complex<f64> {
- %c0 = arith.constant 0.000000e+00 : f64
- %c1 = fir.convert %arg1 : (i64) -> f64
- %c2 = complex.create %c1, %c0 : complex<f64>
- %0 = complex.pow %arg0, %c2 : complex<f64>
+ %0 = complex.powi %arg0, %arg1 : complex<f64>, i64
return %0 : complex<f64>
}
@@ -44,18 +32,12 @@ module {
}
func.func @pow_c16_i4(%arg0: complex<f128>, %arg1: i32) -> complex<f128> {
- %c0 = arith.constant 0.000000e+00 : f128
- %c1 = fir.convert %arg1 : (i32) -> f128
- %c2 = complex.create %c1, %c0 : complex<f128>
- %0 = complex.pow %arg0, %c2 : complex<f128>
+ %0 = complex.powi %arg0, %arg1 : complex<f128>, i32
return %0 : complex<f128>
}
func.func @pow_c16_i8(%arg0: complex<f128>, %arg1: i64) -> complex<f128> {
- %c0 = arith.constant 0.000000e+00 : f128
- %c1 = fir.convert %arg1 : (i64) -> f128
- %c2 = complex.create %c1, %c0 : complex<f128>
- %0 = complex.pow %arg0, %c2 : complex<f128>
+ %0 = complex.powi %arg0, %arg1 : complex<f128>, i64
return %0 : complex<f128>
}
@@ -67,11 +49,11 @@ module {
// CHECK-LABEL: func.func @pow_c4_i4(
// CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) : (complex<f32>, i32) -> complex<f32>
-// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
// CHECK-LABEL: func.func @pow_c4_i8(
// CHECK: fir.call @_FortranAcpowk(%{{.*}}, %{{.*}}) : (complex<f32>, i64) -> complex<f32>
-// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
// CHECK-LABEL: func.func @pow_c4_c4(
// CHECK: fir.call @cpowf(%{{.*}}, %{{.*}}) : (complex<f32>, complex<f32>) -> complex<f32>
@@ -79,11 +61,11 @@ module {
// CHECK-LABEL: func.func @pow_c8_i4(
// CHECK: fir.call @_FortranAzpowi(%{{.*}}, %{{.*}}) : (complex<f64>, i32) -> complex<f64>
-// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
// CHECK-LABEL: func.func @pow_c8_i8(
// CHECK: fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}) : (complex<f64>, i64) -> complex<f64>
-// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
// CHECK-LABEL: func.func @pow_c8_c8(
// CHECK: fir.call @cpow(%{{.*}}, %{{.*}}) : (complex<f64>, complex<f64>) -> complex<f64>
@@ -91,11 +73,11 @@ module {
// CHECK-LABEL: func.func @pow_c16_i4(
// CHECK: fir.call @_FortranAcqpowi(%{{.*}}, %{{.*}}) : (complex<f128>, i32) -> complex<f128>
-// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
// CHECK-LABEL: func.func @pow_c16_i8(
// CHECK: fir.call @_FortranAcqpowk(%{{.*}}, %{{.*}}) : (complex<f128>, i64) -> complex<f128>
-// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
// CHECK-LABEL: func.func @pow_c16_c16(
// CHECK: fir.call @_FortranACPowF128(%{{.*}}, %{{.*}}) : (complex<f128>, complex<f128>) -> complex<f128>
diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index 44590406301eb..ca5103c16889c 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -443,6 +443,32 @@ def PowOp : ComplexArithmeticOp<"pow"> {
}];
}
+//===----------------------------------------------------------------------===//
+// PowiOp
+//===----------------------------------------------------------------------===//
+
+def PowiOp : Complex_Op<"powi",
+ [Pure, Elementwise, SameOperandsAndResultShape,
+ AllTypesMatch<["lhs", "result"]>]> {
+ let summary = "complex number raised to integer power";
+ let description = [{
+ The `powi` operation takes a complex number and an integer exponent.
+
+ Example:
+
+ ```mlir
+ %a = complex.powi %b, %c : complex<f32>, i32
+ ```
+ }];
+
+ let arguments = (ins Complex<AnyFloat>:$lhs,
+ AnySignlessInteger:$rhs);
+ let results = (outs Complex<AnyFloat>:$result);
+
+ let assemblyFormat =
+ "$lhs `,` $rhs attr-dict `:` type($result) `,` type($rhs)";
+}
+
//===----------------------------------------------------------------------===//
// ReOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index 0372f32d6b6df..25e5ab49cdb8c 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -7,9 +7,11 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
@@ -71,10 +73,40 @@ struct PowOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowOp> {
return success();
}
};
+
+// Rewrite complex.powi(z, n) -> complex.pow(z, complex(float(n), 0))
+struct PowiOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowiOp> {
+ using OpRewritePattern<complex::PowiOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(complex::PowiOp op,
+ PatternRewriter &rewriter) const final {
+ auto complexType = cast<ComplexType>(getElementTypeOrSelf(op.getType()));
+ Type elementType = complexType.getElementType();
+
+ Type exponentType = op.getRhs().getType();
+ Type exponentFloatType = elementType;
+ if (auto shapedType = dyn_cast<ShapedType>(exponentType))
+ exponentFloatType = shapedType.cloneWith(std::nullopt, elementType);
+
+ Location loc = op.getLoc();
+ Value exponentReal =
+ rewriter.create<arith::SIToFPOp>(loc, exponentFloatType, op.getRhs());
+ Value zeroImag = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(exponentFloatType));
+ Value exponent = rewriter.create<complex::CreateOp>(
+ loc, op.getLhs().getType(), exponentReal, zeroImag);
+
+ rewriter
+ .replaceOpWithNewOp<complex::PowOp>(op, op.getType(), op.getLhs(),
+ exponent);
+ return success();
+ }
+};
} // namespace
void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
RewritePatternSet &patterns) {
+ patterns.add<PowiOpToROCDLLibraryCalls>(patterns.getContext());
patterns.add<PowOpToROCDLLibraryCalls>(patterns.getContext());
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float32Type>>(
patterns.getContext(), "__ocml_cabs_f32");
@@ -125,11 +157,12 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
populateComplexToROCDLLibraryCallsConversionPatterns(patterns);
ConversionTarget target(getContext());
- target.addLegalDialect<func::FuncDialect>();
- target.addLegalOp<complex::MulOp>();
+ target.addLegalDialect<arith::ArithDialect, func::FuncDialect>();
+ target.addLegalOp<complex::CreateOp, complex::MulOp>();
target.addIllegalOp<complex::AbsOp, complex::CosOp, complex::ExpOp,
- complex::LogOp, complex::PowOp, complex::SinOp,
- complex::SqrtOp, complex::TanOp, complex::TanhOp>();
+ complex::LogOp, complex::PowOp, complex::PowiOp,
+ complex::SinOp, complex::SqrtOp, complex::TanOp,
+ complex::TanhOp>();
if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
}
diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index 31785eb20a642..3711c112cc631 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -13,6 +13,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -175,12 +176,20 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
Value one;
Type opType = getElementTypeOrSelf(op.getType());
- if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>)
+ if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>) {
one = arith::ConstantOp::create(rewriter, loc,
rewriter.getFloatAttr(opType, 1.0));
- else
+ } else if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>) {
+ ...
[truncated]
|
@llvm/pr-subscribers-mlir-complex Author: Akash Banerjee (TIFitis) ChangesAdd a new Patch is 24.12 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/158722.diff 14 Files Affected:
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 466458c05dba7..74a4e8f85c8ff 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -1331,14 +1331,20 @@ mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc,
return genLibCall(builder, loc, mathOp, mathLibFuncType, args);
auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0));
mlir::Value exp = args[1];
- if (!mlir::isa<mlir::ComplexType>(exp.getType())) {
- auto realTy = complexTy.getElementType();
- mlir::Value realExp = builder.createConvert(loc, realTy, exp);
- mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
- exp =
- builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
+ mlir::Value result;
+ if (mlir::isa<mlir::IntegerType>(exp.getType()) ||
+ mlir::isa<mlir::IndexType>(exp.getType())) {
+ result = builder.create<mlir::complex::PowiOp>(loc, args[0], exp);
+ } else {
+ if (!mlir::isa<mlir::ComplexType>(exp.getType())) {
+ auto realTy = complexTy.getElementType();
+ mlir::Value realExp = builder.createConvert(loc, realTy, exp);
+ mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
+ exp = builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp,
+ zero);
+ }
+ result = builder.create<mlir::complex::PowOp>(loc, args[0], exp);
}
- mlir::Value result = builder.create<mlir::complex::PowOp>(loc, args[0], exp);
result = builder.createConvert(loc, mathLibFuncType.getResult(0), result);
return result;
}
diff --git a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
index dced5f90d6924..42f5df160798c 100644
--- a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
+++ b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
@@ -61,63 +61,55 @@ void ConvertComplexPowPass::runOnOperation() {
fir::FirOpBuilder builder(mod, fir::getKindMapping(mod));
- mod.walk([&](complex::PowOp op) {
+ mod.walk([&](complex::PowiOp op) {
builder.setInsertionPoint(op);
Location loc = op.getLoc();
auto complexTy = cast<ComplexType>(op.getType());
auto elemTy = complexTy.getElementType();
-
Value base = op.getLhs();
- Value rhs = op.getRhs();
-
- Value intExp;
- if (auto create = rhs.getDefiningOp<complex::CreateOp>()) {
- if (isZero(create.getImaginary())) {
- if (auto conv = create.getReal().getDefiningOp<fir::ConvertOp>()) {
- if (auto intTy = dyn_cast<IntegerType>(conv.getValue().getType()))
- intExp = conv.getValue();
- }
- }
- }
-
+ Value intExp = op.getRhs();
func::FuncOp callee;
- SmallVector<Value> args;
- if (intExp) {
- unsigned realBits = cast<FloatType>(elemTy).getWidth();
- unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth();
- auto funcTy = builder.getFunctionType(
- {complexTy, builder.getIntegerType(intBits)}, {complexTy});
- if (realBits == 32 && intBits == 32)
- callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy);
- else if (realBits == 32 && intBits == 64)
- callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy);
- else if (realBits == 64 && intBits == 32)
- callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy);
- else if (realBits == 64 && intBits == 64)
- callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy);
- else if (realBits == 128 && intBits == 32)
- callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy);
- else if (realBits == 128 && intBits == 64)
- callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy);
- else
- return;
- args = {base, intExp};
- } else {
- unsigned realBits = cast<FloatType>(elemTy).getWidth();
- auto funcTy =
- builder.getFunctionType({complexTy, complexTy}, {complexTy});
- if (realBits == 32)
- callee = getOrDeclare(builder, loc, "cpowf", funcTy);
- else if (realBits == 64)
- callee = getOrDeclare(builder, loc, "cpow", funcTy);
- else if (realBits == 128)
- callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy);
- else
- return;
- args = {base, rhs};
- }
+ unsigned realBits = cast<FloatType>(elemTy).getWidth();
+ unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth();
+ auto funcTy = builder.getFunctionType(
+ {complexTy, builder.getIntegerType(intBits)}, {complexTy});
+ if (realBits == 32 && intBits == 32)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy);
+ else if (realBits == 32 && intBits == 64)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy);
+ else if (realBits == 64 && intBits == 32)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy);
+ else if (realBits == 64 && intBits == 64)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy);
+ else if (realBits == 128 && intBits == 32)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy);
+ else if (realBits == 128 && intBits == 64)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy);
+ else
+ return;
+ auto call = fir::CallOp::create(builder, loc, callee, {base, intExp});
+ op.replaceAllUsesWith(call.getResult(0));
+ op.erase();
+ });
- auto call = fir::CallOp::create(builder, loc, callee, args);
+ mod.walk([&](complex::PowOp op) {
+ builder.setInsertionPoint(op);
+ Location loc = op.getLoc();
+ auto complexTy = cast<ComplexType>(op.getType());
+ auto elemTy = complexTy.getElementType();
+ unsigned realBits = cast<FloatType>(elemTy).getWidth();
+ func::FuncOp callee;
+ auto funcTy = builder.getFunctionType({complexTy, complexTy}, {complexTy});
+ if (realBits == 32)
+ callee = getOrDeclare(builder, loc, "cpowf", funcTy);
+ else if (realBits == 64)
+ callee = getOrDeclare(builder, loc, "cpow", funcTy);
+ else if (realBits == 128)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy);
+ else
+ return;
+ auto call =
+ fir::CallOp::create(builder, loc, callee, {op.getLhs(), op.getRhs()});
op.replaceAllUsesWith(call.getResult(0));
op.erase();
});
diff --git a/flang/test/Lower/HLFIR/binary-ops.f90 b/flang/test/Lower/HLFIR/binary-ops.f90
index 1fbd333db37c3..7e1691dd1587a 100644
--- a/flang/test/Lower/HLFIR/binary-ops.f90
+++ b/flang/test/Lower/HLFIR/binary-ops.f90
@@ -193,7 +193,7 @@ subroutine complex_to_int_power(x, y, z)
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<complex<f32>>
! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<i32>
-! CHECK: %[[VAL_8:.*]] = complex.pow
+! CHECK: %[[VAL_8:.*]] = complex.powi %[[VAL_6]], %[[VAL_7]] : complex<f32>, i32
subroutine extremum(c, n, l)
integer(8), intent(in) :: l
diff --git a/flang/test/Lower/Intrinsics/pow_complex16i.f90 b/flang/test/Lower/Intrinsics/pow_complex16i.f90
index 1827863a57f43..0b26024b02021 100644
--- a/flang/test/Lower/Intrinsics/pow_complex16i.f90
+++ b/flang/test/Lower/Intrinsics/pow_complex16i.f90
@@ -4,7 +4,7 @@
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
! PRECISE: fir.call @_FortranAcqpowi({{.*}}){{.*}}: (complex<f128>, i32) -> complex<f128>
-! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
+! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f128>
complex(16) :: a
integer(4) :: b
b = a ** b
diff --git a/flang/test/Lower/Intrinsics/pow_complex16k.f90 b/flang/test/Lower/Intrinsics/pow_complex16k.f90
index 039dfd5152a06..90a9f5e03628d 100644
--- a/flang/test/Lower/Intrinsics/pow_complex16k.f90
+++ b/flang/test/Lower/Intrinsics/pow_complex16k.f90
@@ -4,7 +4,7 @@
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
! PRECISE: fir.call @_FortranAcqpowk({{.*}}){{.*}}: (complex<f128>, i64) -> complex<f128>
-! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
+! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f128>
complex(16) :: a
integer(8) :: b
b = a ** b
diff --git a/flang/test/Lower/amdgcn-complex.f90 b/flang/test/Lower/amdgcn-complex.f90
index 4ee5de4d2842e..a28eaea82379b 100644
--- a/flang/test/Lower/amdgcn-complex.f90
+++ b/flang/test/Lower/amdgcn-complex.f90
@@ -25,3 +25,12 @@ subroutine pow_test(a, b, c)
complex :: a, b, c
a = b**c
end subroutine pow_test
+
+! CHECK-LABEL: func @_QPpowi_test(
+! CHECK: complex.powi
+! CHECK-NOT: fir.call @_FortranAcpowi
+subroutine powi_test(a, b, c)
+ complex :: a, b
+ integer :: i
+ b = a ** i
+end subroutine powi_test
diff --git a/flang/test/Lower/power-operator.f90 b/flang/test/Lower/power-operator.f90
index 3058927144248..9f74d172a6bb2 100644
--- a/flang/test/Lower/power-operator.f90
+++ b/flang/test/Lower/power-operator.f90
@@ -96,7 +96,7 @@ subroutine pow_c4_i4(x, y, z)
complex :: x, z
integer :: y
z = x ** y
- ! CHECK: complex.pow
+ ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f32>, i32
! PRECISE: fir.call @_FortranAcpowi
end subroutine
@@ -105,7 +105,7 @@ subroutine pow_c4_i8(x, y, z)
complex :: x, z
integer(8) :: y
z = x ** y
- ! CHECK: complex.pow
+ ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f32>, i64
! PRECISE: fir.call @_FortranAcpowk
end subroutine
@@ -114,7 +114,7 @@ subroutine pow_c8_i4(x, y, z)
complex(8) :: x, z
integer :: y
z = x ** y
- ! CHECK: complex.pow
+ ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f64>, i32
! PRECISE: fir.call @_FortranAzpowi
end subroutine
@@ -123,7 +123,7 @@ subroutine pow_c8_i8(x, y, z)
complex(8) :: x, z
integer(8) :: y
z = x ** y
- ! CHECK: complex.pow
+ ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f64>, i64
! PRECISE: fir.call @_FortranAzpowk
end subroutine
@@ -142,4 +142,3 @@ subroutine pow_c8_c8(x, y, z)
! CHECK: complex.pow %{{.*}}, %{{.*}} : complex<f64>
! PRECISE: fir.call @cpow
end subroutine
-
diff --git a/flang/test/Transforms/convert-complex-pow.fir b/flang/test/Transforms/convert-complex-pow.fir
index d980817aba9b9..4555fea61e496 100644
--- a/flang/test/Transforms/convert-complex-pow.fir
+++ b/flang/test/Transforms/convert-complex-pow.fir
@@ -2,18 +2,12 @@
module {
func.func @pow_c4_i4(%arg0: complex<f32>, %arg1: i32) -> complex<f32> {
- %c0 = arith.constant 0.000000e+00 : f32
- %c1 = fir.convert %arg1 : (i32) -> f32
- %c2 = complex.create %c1, %c0 : complex<f32>
- %0 = complex.pow %arg0, %c2 : complex<f32>
+ %0 = complex.powi %arg0, %arg1 : complex<f32>, i32
return %0 : complex<f32>
}
func.func @pow_c4_i8(%arg0: complex<f32>, %arg1: i64) -> complex<f32> {
- %c0 = arith.constant 0.000000e+00 : f32
- %c1 = fir.convert %arg1 : (i64) -> f32
- %c2 = complex.create %c1, %c0 : complex<f32>
- %0 = complex.pow %arg0, %c2 : complex<f32>
+ %0 = complex.powi %arg0, %arg1 : complex<f32>, i64
return %0 : complex<f32>
}
@@ -23,18 +17,12 @@ module {
}
func.func @pow_c8_i4(%arg0: complex<f64>, %arg1: i32) -> complex<f64> {
- %c0 = arith.constant 0.000000e+00 : f64
- %c1 = fir.convert %arg1 : (i32) -> f64
- %c2 = complex.create %c1, %c0 : complex<f64>
- %0 = complex.pow %arg0, %c2 : complex<f64>
+ %0 = complex.powi %arg0, %arg1 : complex<f64>, i32
return %0 : complex<f64>
}
func.func @pow_c8_i8(%arg0: complex<f64>, %arg1: i64) -> complex<f64> {
- %c0 = arith.constant 0.000000e+00 : f64
- %c1 = fir.convert %arg1 : (i64) -> f64
- %c2 = complex.create %c1, %c0 : complex<f64>
- %0 = complex.pow %arg0, %c2 : complex<f64>
+ %0 = complex.powi %arg0, %arg1 : complex<f64>, i64
return %0 : complex<f64>
}
@@ -44,18 +32,12 @@ module {
}
func.func @pow_c16_i4(%arg0: complex<f128>, %arg1: i32) -> complex<f128> {
- %c0 = arith.constant 0.000000e+00 : f128
- %c1 = fir.convert %arg1 : (i32) -> f128
- %c2 = complex.create %c1, %c0 : complex<f128>
- %0 = complex.pow %arg0, %c2 : complex<f128>
+ %0 = complex.powi %arg0, %arg1 : complex<f128>, i32
return %0 : complex<f128>
}
func.func @pow_c16_i8(%arg0: complex<f128>, %arg1: i64) -> complex<f128> {
- %c0 = arith.constant 0.000000e+00 : f128
- %c1 = fir.convert %arg1 : (i64) -> f128
- %c2 = complex.create %c1, %c0 : complex<f128>
- %0 = complex.pow %arg0, %c2 : complex<f128>
+ %0 = complex.powi %arg0, %arg1 : complex<f128>, i64
return %0 : complex<f128>
}
@@ -67,11 +49,11 @@ module {
// CHECK-LABEL: func.func @pow_c4_i4(
// CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) : (complex<f32>, i32) -> complex<f32>
-// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
// CHECK-LABEL: func.func @pow_c4_i8(
// CHECK: fir.call @_FortranAcpowk(%{{.*}}, %{{.*}}) : (complex<f32>, i64) -> complex<f32>
-// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
// CHECK-LABEL: func.func @pow_c4_c4(
// CHECK: fir.call @cpowf(%{{.*}}, %{{.*}}) : (complex<f32>, complex<f32>) -> complex<f32>
@@ -79,11 +61,11 @@ module {
// CHECK-LABEL: func.func @pow_c8_i4(
// CHECK: fir.call @_FortranAzpowi(%{{.*}}, %{{.*}}) : (complex<f64>, i32) -> complex<f64>
-// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
// CHECK-LABEL: func.func @pow_c8_i8(
// CHECK: fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}) : (complex<f64>, i64) -> complex<f64>
-// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
// CHECK-LABEL: func.func @pow_c8_c8(
// CHECK: fir.call @cpow(%{{.*}}, %{{.*}}) : (complex<f64>, complex<f64>) -> complex<f64>
@@ -91,11 +73,11 @@ module {
// CHECK-LABEL: func.func @pow_c16_i4(
// CHECK: fir.call @_FortranAcqpowi(%{{.*}}, %{{.*}}) : (complex<f128>, i32) -> complex<f128>
-// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
// CHECK-LABEL: func.func @pow_c16_i8(
// CHECK: fir.call @_FortranAcqpowk(%{{.*}}, %{{.*}}) : (complex<f128>, i64) -> complex<f128>
-// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
// CHECK-LABEL: func.func @pow_c16_c16(
// CHECK: fir.call @_FortranACPowF128(%{{.*}}, %{{.*}}) : (complex<f128>, complex<f128>) -> complex<f128>
diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index 44590406301eb..ca5103c16889c 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -443,6 +443,32 @@ def PowOp : ComplexArithmeticOp<"pow"> {
}];
}
+//===----------------------------------------------------------------------===//
+// PowiOp
+//===----------------------------------------------------------------------===//
+
+def PowiOp : Complex_Op<"powi",
+ [Pure, Elementwise, SameOperandsAndResultShape,
+ AllTypesMatch<["lhs", "result"]>]> {
+ let summary = "complex number raised to integer power";
+ let description = [{
+ The `powi` operation takes a complex number and an integer exponent.
+
+ Example:
+
+ ```mlir
+ %a = complex.powi %b, %c : complex<f32>, i32
+ ```
+ }];
+
+ let arguments = (ins Complex<AnyFloat>:$lhs,
+ AnySignlessInteger:$rhs);
+ let results = (outs Complex<AnyFloat>:$result);
+
+ let assemblyFormat =
+ "$lhs `,` $rhs attr-dict `:` type($result) `,` type($rhs)";
+}
+
//===----------------------------------------------------------------------===//
// ReOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index 0372f32d6b6df..25e5ab49cdb8c 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -7,9 +7,11 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
@@ -71,10 +73,40 @@ struct PowOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowOp> {
return success();
}
};
+
+// Rewrite complex.powi(z, n) -> complex.pow(z, complex(float(n), 0))
+struct PowiOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowiOp> {
+ using OpRewritePattern<complex::PowiOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(complex::PowiOp op,
+ PatternRewriter &rewriter) const final {
+ auto complexType = cast<ComplexType>(getElementTypeOrSelf(op.getType()));
+ Type elementType = complexType.getElementType();
+
+ Type exponentType = op.getRhs().getType();
+ Type exponentFloatType = elementType;
+ if (auto shapedType = dyn_cast<ShapedType>(exponentType))
+ exponentFloatType = shapedType.cloneWith(std::nullopt, elementType);
+
+ Location loc = op.getLoc();
+ Value exponentReal =
+ rewriter.create<arith::SIToFPOp>(loc, exponentFloatType, op.getRhs());
+ Value zeroImag = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(exponentFloatType));
+ Value exponent = rewriter.create<complex::CreateOp>(
+ loc, op.getLhs().getType(), exponentReal, zeroImag);
+
+ rewriter
+ .replaceOpWithNewOp<complex::PowOp>(op, op.getType(), op.getLhs(),
+ exponent);
+ return success();
+ }
+};
} // namespace
void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
RewritePatternSet &patterns) {
+ patterns.add<PowiOpToROCDLLibraryCalls>(patterns.getContext());
patterns.add<PowOpToROCDLLibraryCalls>(patterns.getContext());
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float32Type>>(
patterns.getContext(), "__ocml_cabs_f32");
@@ -125,11 +157,12 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
populateComplexToROCDLLibraryCallsConversionPatterns(patterns);
ConversionTarget target(getContext());
- target.addLegalDialect<func::FuncDialect>();
- target.addLegalOp<complex::MulOp>();
+ target.addLegalDialect<arith::ArithDialect, func::FuncDialect>();
+ target.addLegalOp<complex::CreateOp, complex::MulOp>();
target.addIllegalOp<complex::AbsOp, complex::CosOp, complex::ExpOp,
- complex::LogOp, complex::PowOp, complex::SinOp,
- complex::SqrtOp, complex::TanOp, complex::TanhOp>();
+ complex::LogOp, complex::PowOp, complex::PowiOp,
+ complex::SinOp, complex::SqrtOp, complex::TanOp,
+ complex::TanhOp>();
if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
}
diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index 31785eb20a642..3711c112cc631 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -13,6 +13,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -175,12 +176,20 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
Value one;
Type opType = getElementTypeOrSelf(op.getType());
- if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>)
+ if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>) {
one = arith::ConstantOp::create(rewriter, loc,
rewriter.getFloatAttr(opType, 1.0));
- else
+ } else if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>) {
+ ...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR adds a new complex.powi
operation to MLIR's complex dialect for computing complex numbers raised to integer powers. The implementation provides more efficient handling of complex-to-integer power operations compared to the generic complex power operation.
Key changes include:
- Addition of the new
PowiOp
operation definition in the Complex dialect - Integration with algebraic simplification passes for optimization
- Support for conversion to ROCDL library calls
- Updates to Flang frontend to generate the new operation
Reviewed Changes
Copilot reviewed 14 out of 14 changed files in this pull request and generated 2 comments.
Show a summary per file
File | Description |
---|---|
mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td | Defines the new PowiOp operation with complex base and integer exponent |
mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp | Adds algebraic simplification patterns for complex.powi operations |
mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp | Implements conversion from complex.powi to ROCDL library calls |
flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp | Updates complex power conversion to handle both powi and pow operations |
flang/lib/Optimizer/Builder/IntrinsicCall.cpp | Modifies intrinsic call generation to use powi for integer exponents |
Multiple test files | Updates test expectations to reflect the new powi operation usage |
auto imagPart = rewriter.getFloatAttr(elementType, 0.0); | ||
one = rewriter.create<complex::ConstantOp>( | ||
loc, complexTy, rewriter.getArrayAttr({realPart, imagPart})); | ||
} else { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nitpick] The nested if-else chain with constexpr conditions could be simplified using if constexpr for all branches to improve readability and consistency.
} else { | |
} else if constexpr (true) { |
Copilot uses AI. Check for mistakes.
patterns.add<PowiOpToROCDLLibraryCalls>(patterns.getContext()); | ||
patterns.add<PowOpToROCDLLibraryCalls>(patterns.getContext()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nitpick] Consider reordering these pattern additions to maintain alphabetical order for better code organization and maintainability.
Copilot uses AI. Check for mistakes.
✅ With the latest revision this PR passed the C/C++ code formatter. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thank you!
Just a couple of minor comments.
What is the LLVM lowering here? |
AllTypesMatch<["lhs", "result"]>]> { | ||
let summary = "complex number raised to integer power"; | ||
let description = [{ | ||
The `powi` operation takes a complex number and an integer exponent. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The signed aspects of the operands, as well at the overflow (or other special) behaviors should be specified here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have updated the description to be consistent with other similar Powi ops. Let me know if I'm missing anything.
Good point. We can try to reuse the conversion for |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM once the existing comments are addressed.
45a1331
to
6faad70
Compare
@joker-eph @vzakhari At the moment both |
That isn't in MLIR right now, so that's not generally usable.
I'm confused: are these the same op? I would assume the semantics differs of your wouldn't add a new op. So how com you can just convert one to the other? |
You can always convert The introduction of |
I've added |
Thanks, LG! |
}); | ||
|
||
auto call = fir::CallOp::create(builder, loc, callee, args); | ||
mod.walk([&](complex::PowOp op) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should not walk multiple times if we can do it in a single traversal, can you replace this with a walk on Operation* and dispatch inside the walk?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've updated this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The powi
part looks good to me. Are you planning to merge it, and then rebase the other PR for the Flang changes for the final review?
I plan on landing both PRs at once. This PR depends on #158642, which should land first. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with some final comments.
if constexpr (std::is_same_v<T, mlir::complex::PowOp>) { | ||
auto resultType = mathLibFuncType.getResult(0); | ||
result = T::create(builder, loc, resultType, args); | ||
} else if constexpr (std::is_same_v<T, mlir::complex::PowiOp>) { | ||
auto resultType = mathLibFuncType.getResult(0); | ||
auto fmfAttr = mlir::arith::FastMathFlagsAttr::get( | ||
builder.getContext(), builder.getFastMathFlags()); | ||
result = builder.create<mlir::complex::PowiOp>(loc, resultType, args[0], | ||
args[1], fmfAttr); | ||
} else { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we really need all this code? I believe just a simple T::create(buider, loc, args)
should work, because of the type constraints in the operations definitions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're right, I've simplified it. Thanks for catching.
Type elementType = complexTy.getElementType(); | ||
auto realPart = rewriter.getFloatAttr(elementType, 1.0); | ||
auto imagPart = rewriter.getFloatAttr(elementType, 0.0); | ||
one = rewriter.create<complex::ConstantOp>( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe all the create
methods of the rewriter will become deprecated soon, so complex::ConstantOp::create
is a better alternative. There are other cases below.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
This PR introduces a new `ConvertComplexPow` pass for Flang that handles complex power operations. The change forces lowering to complex.pow operations when `--math-runtime=precise` is not used, then uses the `ConvertComplexPow` pass to convert these operations back to library calls. - Adds a new `ConvertComplexPow` pass that converts complex.pow ops to appropriate runtime library calls - Updates complex power lowering to use `complex.pow` operations by default instead of direct library calls #158722 Adds a new `complex.powi` op enabling algebraic optimisations.
6dbb370
to
78d9190
Compare
This PR adds a new complex.powi operation to MLIR's complex dialect for computing complex numbers raised to integer powers.
Key changes include:
PowiOp
operation definition in the Complex dialectThis depends on #158642.